import argparse
import numpy as np
import time

from environment import MovieLensEnv
from agent import create_agent

def save_checkpoint(round, regrets, times, **kw):
    """Write a run's regret array and timing array to ``.npy`` files.

    Both arrays are saved under ``./results-movielens`` (relative to the
    current working directory). The regrets go into ``checkpoint/`` and the
    times go into ``times/``. Both files get the same run-identifying
    filename. Missing output directories are created first, so a fresh
    checkout does not fail with FileNotFoundError.

    Args:
        round: Number of rounds of the run; written into the filename as
            ``T``. (The name shadows the ``round`` builtin but is kept for
            caller compatibility.)
        regrets: Array-like written with ``np.save`` into ``checkpoint/``.
        times: Array-like written with ``np.save`` into ``times/``.
        **kw: Run configuration. Must contain ``N``, ``K``, ``d``, ``lam``,
            ``seed``, ``algorithm``, ``oracle_type`` and ``B``; any other
            keys are ignored.

    Raises:
        KeyError: If one of the required configuration keys is missing.
    """
    from pathlib import Path

    # A single filename identifies the run; both output files share it.
    filename = (
        f"alg={kw['algorithm']}_T={round}_N={kw['N']}_K={kw['K']}_d={kw['d']}"
        f"_lam={kw['lam']}_seed={kw['seed']}_oracle={kw['oracle_type']}"
        f"_B={kw['B']}.npy"
    )
    root = Path('./results-movielens')

    for subdir, array in (('checkpoint', regrets), ('times', times)):
        out_dir = root / subdir
        # Without this, open() raises FileNotFoundError if the directory is absent.
        out_dir.mkdir(parents=True, exist_ok=True)
        with open(out_dir / filename, 'wb') as f:
            np.save(f, array)

def run_experiment(**kw):
    """Run one experiment loop on ``MovieLensEnv`` and save the results.

    For each round ``t`` in ``1..T``:

    1. ``env.new(t)`` is called and the agent observes ``env.get_features()``.
    2. The environment's ``oracle`` is queried with ``agent.U``, and the agent
       chooses an action from that recommendation.
    3. ``env.play`` returns feedback ``Y``, a ``stop`` index and the round's
       regret.
    4. The agent is updated with the chosen items' features for positions
       ``0..stop``.

    Args:
        **kw: Experiment configuration, forwarded unchanged to
            ``MovieLensEnv``, ``create_agent`` and ``save_checkpoint``.
            ``T`` (number of rounds) is read here directly.

    Returns:
        np.ndarray: Cumulative regret of length ``T + 1``. Index 0 is the
        initial 0.0. Each per-round regret is clipped at 0 before summing.

    Side effects:
        Prints agent diagnostics and progress every 500 rounds. Writes the
        cumulative regret and timing arrays via ``save_checkpoint``.
    """
    env = MovieLensEnv(**kw)
    agent = create_agent(**kw)
    T = kw['T']

    # Index 0 is a 0.0 sentinel, so entry t corresponds to round t.
    regrets = [0.0]
    # Milliseconds of wall-clock time since the loop started, per round.
    times = [0.0]
    # NOTE(review): this running sum is NOT clipped, unlike the saved curve.
    # The logged regret can therefore differ from the saved one whenever a
    # per-round regret is negative. Confirm that this is intended.
    sum_regret = 0

    total_start = time.perf_counter()

    print(agent.kappa)
    print(agent.beta)

    for t in range(1, T + 1):
        env.new(t)
        X = env.get_features()
        agent.observe(X, t)

        recc = env.oracle(agent.U)
        action = agent.choose_action(recc)

        Y, stop, regret = env.play(t, action)  # Y: (stop+1, 2)

        # Chosen items' features, one (1, d) slice per position, truncated to
        # the first stop+1 positions so they line up with the rows of Y.
        observedX = X[action].reshape(env.K, 1, -1)[:(stop + 1)]
        agent.update(observedX, Y, stop)

        times.append((time.perf_counter() - total_start) * 1000)
        regrets.append(regret)
        sum_regret += regret

        if t % 500 == 0:
            print(f'Round:{t}, AGENT ACTION:{action}, OPT ACTION:{env._get_OPT(t)}, regret={sum_regret}')
            print(f'TIME: {times[-1]:.3f}')

    print(f'AGENT beta: {agent.beta}')
    print(f'TIME: {times[-1]:.3f}')

    regrets = np.array(regrets)
    # Clip negative per-round regrets to 0 before accumulating.
    cumulative_regrets = np.cumsum(np.where(regrets < 0, 0, regrets))
    save_checkpoint(T, cumulative_regrets, np.array(times), **kw)
    return cumulative_regrets

def main(argv=None):
    """Parse command-line options and run one MovieLens experiment.

    Args:
        argv: Optional sequence of argument strings. When ``None`` (the
            default, as used by the script entry point), argparse reads
            ``sys.argv[1:]``. Passing a list lets callers run an experiment
            programmatically.

    Returns:
        The value returned by ``run_experiment`` (the cumulative regret
        array).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--T', type=int, default=5000, help='number of rounds')
    parser.add_argument('--N', type=int, default=1642)
    parser.add_argument('--K', type=int, default=5)
    parser.add_argument('--dim', type=int, default=25)
    parser.add_argument('--lam', type=float, default=1.0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--alg', default='UCBCLB',
                        help='agent algorithm, e.g. C3UCB, UCBCLB, UCBCCA, CLogUCB')
    parser.add_argument('--oracle', default='optimal',
                        help="oracle type, e.g. 'optimal'")
    parser.add_argument('--B', type=float, default=1.0)
    parser.add_argument('--C', type=float, default=1.0)

    args = parser.parse_args(argv)

    # kappa = e^B / (1 + e^B)^2, the derivative of the logistic sigmoid
    # evaluated at B.
    # NOTE(review): np.exp(B) overflows to inf for B > ~709, which makes
    # this nan. The formula is kept as-is for bit-exact reproducibility.
    kappa = np.exp(args.B) / ((1 + np.exp(args.B)) ** 2)

    # The CLI flag --dim is passed on under the key 'd', and --alg / --oracle
    # under 'algorithm' / 'oracle_type'.
    kw = {'T': args.T, 'N': args.N, 'K': args.K, 'd': args.dim,
          'lam': args.lam, 'seed': args.seed,
          'algorithm': args.alg, 'oracle_type': args.oracle,
          'kappa': kappa, 'B': args.B, 'C': args.C}

    return run_experiment(**kw)

if __name__ == "__main__":
    # Script entry point: parse CLI flags from sys.argv and run one experiment.
    main()